#!venv/bin/python

"""
Splits and samples any line-delimited dataset. This only supports absolute
sample sizes and not proportions. To make things easy, this does load the entire
input dataset into memory, so there is a limit on dataset size.
"""

import argparse
import random


def sample_pair(x):
    """
    Parse one ``name,size`` command-line token into a ``(name, size)`` tuple.

    Intended for use as an argparse ``type=`` callable.

    Args:
        x: Token of the form ``name,size`` with exactly one comma, e.g.
            ``out1,10``.

    Returns:
        A ``(name, size)`` tuple where ``name`` is a string and ``size`` is a
        non-negative int.

    Raises:
        argparse.ArgumentTypeError: If the token does not contain exactly one
            comma, the size is not an integer, or the size is negative.
    """
    try:
        name, size = x.split(',')
        size = int(size)
    # Catch only the conversion failures; a bare ``except`` would also
    # swallow KeyboardInterrupt and SystemExit.
    except (ValueError, TypeError, AttributeError) as err:
        raise argparse.ArgumentTypeError(f"Tuples must be name(string),size(int): {x}") from err

    # A negative size would turn into a negative slice bound downstream and
    # select the wrong lines, so reject it here.
    if size < 0:
        raise argparse.ArgumentTypeError(f"Sample size must be non-negative: {x}")
    return name, size


def split_sample(input_file, outputs):
    """
    Shuffle the lines of ``input_file`` and write disjoint samples to files.

    The whole input is read into memory, shuffled with ``random.shuffle``, and
    then consecutive slices of the shuffled lines are written to each output
    in the order given, so no input line lands in more than one output.

    Args:
        input_file: Path of the line-delimited input dataset.
        outputs: Sequence of ``(out_file, size)`` pairs; ``size`` lines are
            written to ``out_file``, each terminated by a newline.

    Raises:
        ValueError: If any size is negative, or if the sizes add up to more
            than the number of lines in the input. Nothing is written in
            either case.
    """
    # Validate with explicit raises rather than ``assert``: assertions are
    # stripped under ``python -O`` and the check would silently vanish.
    sizes = [size for _, size in outputs]
    if any(size < 0 for size in sizes):
        # Negative sizes would become negative slice bounds and pick the
        # wrong lines.
        raise ValueError("Sample sizes must be non-negative.")

    with open(input_file, 'r') as f:
        dataset = f.read().splitlines()
    random.shuffle(dataset)

    # ensure we don't run past the end of the dataset
    if sum(sizes) > len(dataset):
        raise ValueError("Sum of sample sizes must be less than or equal to the dataset size.")

    start = 0
    for out_file, size in outputs:
        end = start + size
        with open(out_file, 'w') as out:
            out.writelines(f"{line}\n" for line in dataset[start:end])
        start = end


def main():
    """
    Command-line entry point: parse the arguments and run the split.

    Expects ``--input FILE`` and ``--outputs name,size [name,size ...]``; each
    output token is parsed by ``sample_pair``.
    """
    parser = argparse.ArgumentParser(prog='split-and-sample')
    parser.add_argument('--input', required=True, help="Input file")
    # Required: without it args.outputs is None and split_sample would fail
    # with a TypeError traceback instead of a usage error.
    parser.add_argument('--outputs', type=sample_pair, nargs='+', required=True,
                        help="Output files and sample sizes as list, e.g. out1,10 out2,100")
    args = parser.parse_args()

    split_sample(args.input, args.outputs)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
